package org.springblade.ocpp.conf;

import cn.hutool.core.collection.CollUtil;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;

import java.util.List;
import java.util.Map;

/**
 * WebSocket server configuration for the OCPP module.
 *
 * <p>This class serves two purposes:
 * <ul>
 *   <li>As a Spring {@code @Configuration}, it registers a {@link ServerEndpointExporter} and tunes the
 *   servlet WebSocket container (message buffer sizes and session idle timeout).</li>
 *   <li>As a {@link ServerEndpointConfig.Configurator}, it copies selected handshake request headers into
 *   the endpoint config's user properties so endpoint code can read them after the connection opens.</li>
 * </ul>
 *
 * @author lhb
 * @date 2024/9/10 11:49 AM
 */
@Configuration
public class WebSocketHeaderConfig extends ServerEndpointConfig.Configurator {

	/** Maximum text / binary message buffer size in bytes (10,240,000 bytes, roughly 9.8 MiB). */
	private static final int MAX_MESSAGE_BUFFER_SIZE = 10240000;

	/** Maximum session idle timeout: 15 minutes, expressed in milliseconds. */
	private static final long MAX_SESSION_IDLE_TIMEOUT_MS = 15 * 60000L;

	/** Standard HTTP authorization header name; also stored as a user-property key. */
	private static final String AUTHORIZATION_HEADER = "Authorization";

	/**
	 * Misspelled header name / user-property key used historically by this class. It is still read
	 * (first) and still written so that existing clients and endpoint code keep working unchanged.
	 */
	private static final String LEGACY_AUTHORIZATION_KEY = "Authorzation";

	/** Forwarded-client-address header name; also used as the user-property key. */
	private static final String FORWARDED_FOR_HEADER = "x-forwarded-for";

	/**
	 * Exports {@code @ServerEndpoint}-annotated endpoints to the standard Jakarta WebSocket server container.
	 *
	 * @return the exporter bean
	 */
	@Bean
	public ServerEndpointExporter serverEndpointExporter() {
		return new ServerEndpointExporter();
	}

	/**
	 * Configures the text and binary message buffer sizes and the session idle timeout of the
	 * WebSocket container.
	 *
	 * <p>The buffers are enlarged so that large messages from third-party peers do not cause the
	 * connection to be closed with WebSocket status code 1009 (message too big).
	 *
	 * @return the factory bean that applies these settings to the container
	 */
	@Bean
	public ServletServerContainerFactoryBean createWebSocketContainer() {
		ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean();
		container.setMaxTextMessageBufferSize(MAX_MESSAGE_BUFFER_SIZE);
		container.setMaxBinaryMessageBufferSize(MAX_MESSAGE_BUFFER_SIZE);
		container.setMaxSessionIdleTimeout(MAX_SESSION_IDLE_TIMEOUT_MS);
		return container;
	}

	/**
	 * Runs during the opening handshake, before the connection is established.
	 *
	 * <p>Copies the first value of the authorization header and of the {@code x-forwarded-for}
	 * header (when present) into {@code sec.getUserProperties()}, then delegates to the default
	 * configurator behavior.
	 *
	 * @param sec      the endpoint configuration whose user properties receive the header values
	 * @param request  the handshake request carrying the headers
	 * @param response the handshake response (not modified here)
	 */
	@Override
	public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
		final Map<String, Object> userProperties = sec.getUserProperties();
		final Map<String, List<String>> headers = request.getHeaders();

		// The legacy misspelled header is checked first, so whenever the previous implementation
		// captured a value, the same value is captured now. The standard "Authorization" header,
		// which was never read before, is used only as a fallback.
		String authorization = firstHeaderValue(headers, LEGACY_AUTHORIZATION_KEY);
		if (authorization == null) {
			authorization = firstHeaderValue(headers, AUTHORIZATION_HEADER);
		}
		if (authorization != null) {
			// Keep the legacy key for existing readers; also expose the correctly spelled key.
			userProperties.put(LEGACY_AUTHORIZATION_KEY, authorization);
			userProperties.put(AUTHORIZATION_HEADER, authorization);
		}

		// NOTE(review): this stores the raw first header value. If proxies append to it, it may be a
		// comma-separated list ("client, proxy1, ...") rather than a single IP. Confirm how readers parse it.
		final String forwardedFor = firstHeaderValue(headers, FORWARDED_FOR_HEADER);
		if (forwardedFor != null) {
			userProperties.put(FORWARDED_FOR_HEADER, forwardedFor);
		}

		super.modifyHandshake(sec, request, response);
	}

	/**
	 * Creates the endpoint instance, i.e. the object of the class annotated with {@code @ServerEndpoint}.
	 *
	 * <p>Currently delegates unchanged to the default configurator; kept as an explicit extension point.
	 * NOTE(review): the default implementation is not Spring-aware, so presumably dependency injection
	 * into endpoint fields does not happen here. Confirm against how endpoints obtain their collaborators.
	 *
	 * @param clazz the endpoint class to instantiate
	 * @param <T>   the endpoint type
	 * @return a new endpoint instance
	 * @throws InstantiationException if the endpoint cannot be instantiated
	 */
	@Override
	public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
		return super.getEndpointInstance(clazz);
	}

	/**
	 * Returns the first value of the named header.
	 *
	 * @param headers the handshake request headers; may be {@code null}
	 * @param name    the header name to look up
	 * @return the first header value, or {@code null} if the header is absent or has no values
	 */
	private static String firstHeaderValue(Map<String, List<String>> headers, String name) {
		if (headers == null) {
			return null;
		}
		final List<String> values = headers.get(name);
		return CollUtil.isEmpty(values) ? null : values.get(0);
	}
}
